import sys
sys.path.append('/vinbrain/huyta/domainadaptation/ssda/')
from ssda.learner import Learner, auxs
from ssda.model import AdaptationModel
from ssda.dataset import AdaptationDataset
# --- Adaptation hyper-parameters ---
n_aux_classes = 5  # number of auxiliary-task classes (one per entry in aux_labels)
# Labels for the auxiliary self-supervised task — presumably rotation angles
# in degrees (TODO confirm against ssda.dataset.AdaptationDataset).
aux_labels = [90, 45, 0, -45, -90]
img_size = 256  # square side length (pixels) fed to the network
import torch
from torch import nn
sys.path.append('/vinbrain/huyta/')
from COVID_wave_4.utils.layers import PooledSelfAttention2d, AdaptiveConcatPool
from COVID_wave_4.utils.fastai_utils import create_effnet
# modified version
class SAPoolClassifier(nn.Sequential):
    """Classification head: pooled self-attention -> concat pooling -> MLP.

    Subclasses ``nn.Sequential`` and relies on ``nn.Module.__setattr__`` to
    register the sub-modules in assignment order, so the forward pass runs
    SA -> pool -> fc. The attribute names (``SA``, ``pool``, ``fc``) are part
    of the public interface — the script later splits the model on them.

    Args:
        in_c: channel count of the incoming feature map. The MLP input is
            ``in_c * 2`` because concat pooling doubles the feature count
            (presumably avg+max concatenated — confirm in AdaptiveConcatPool).
        out_c: number of output logits.
    """

    def __init__(self, in_c, out_c):
        super().__init__()  # zero-arg super (Py3 idiom) instead of legacy two-arg form
        self.SA = PooledSelfAttention2d(in_c)
        self.pool = AdaptiveConcatPool()
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.BatchNorm1d(in_c * 2), nn.Dropout(0.25),
            nn.Linear(in_c * 2, 512), nn.ReLU(),
            nn.BatchNorm1d(512), nn.Dropout(0.5),
            nn.Linear(512, out_c),
        )
# Build the source-domain model: EfficientNet-B5 (noisy-student) feature
# extractor followed by the self-attention classifier head (2048 -> 5 labels).
model = nn.Sequential(create_effnet(False, 'tf_efficientnet_b5_ns'), SAPoolClassifier(2048, 5))
# Warm-start from the wave-4 checkpoint. strict=False silently skips
# missing/unexpected keys — NOTE(review): consider capturing and logging the
# IncompatibleKeys result to confirm the load actually matched the weights.
states = torch.load('/vinbrain/huyta/COVID_wave_4/train_notebooks/models/train_all_512_no_fastai.pth')
model.load_state_dict(states, strict=False)
# Split the pretrained model into a feature backbone (effnet + SA + pooling,
# flattened to a vector) and the MLP classifier, so AdaptationModel can attach
# its auxiliary head between them.
backbone = nn.Sequential(model[0], model[1].SA, model[1].pool, nn.Flatten())
classifier = model[1].fc
adapt_model = AdaptationModel(backbone=backbone, classifier=classifier, n_aux_classes=n_aux_classes)
# Data-parallel over GPUs 0 and 1; trailing semicolon suppresses notebook echo.
adapt_model = nn.DataParallel(adapt_model, device_ids=[0, 1])
adapt_model.cuda();
import pandas as pd
# ----- Source-domain (labelled) data -----
data_path = '/u01/data/COVID_Data_Relabel/data/'
# Multi-label annotation columns consumed by AdaptationDataset.
label_cols_list = ['Covid', 'Airspace_Opacity', 'Consolidation', 'Atelectasis', 'Lung_Lesion']
source_df_train = pd.read_csv('/vinbrain/huyta/COVID_wave_4/csv/train_clean.csv')
source_df_valid = pd.read_csv('/vinbrain/huyta/COVID_wave_4/csv/valid_clean.csv')
# Prefix the relative image names with the data root to get absolute paths.
source_df_train['Images'] = data_path + source_df_train['Images']
source_df_valid['Images'] = data_path + source_df_valid['Images']
from pathlib import Path
from PIL import Image
# ----- Target-domain (unlabelled) images -----
target_imgs = Path('/vinbrain/huyta/COVID_wave_4/data/bacgiang/all_imgs/').rglob('*.png')
# Filter out Jupyter checkpoint artifacts ('.ipynb_checkpoints' paths).
# NOTE(review): the trailing [1:] silently drops the first image — looks
# deliberate (bad file?) but there is no visible reason; confirm.
target_imgs = [str(c) for c in target_imgs if 'ipynb' not in str(c)][1:]
target_df = pd.DataFrame({'Images': target_imgs})
# Notebook-style sanity check: opens (and would display) the first target image;
# the return value is intentionally discarded.
Image.open(target_df.at[0, 'Images'])
from torchvision.transforms import *
# Training-time augmentations. Normalisation uses ImageNet statistics to match
# the pretrained EfficientNet backbone.
train_augs = Compose([
    ToPILImage(),
    # Random crop covering 80%-125% of the area, resized to img_size.
    # NOTE(review): a scale upper bound > 1.0 is unusual for RandomResizedCrop
    # (area larger than the source image) — confirm this is intended.
    RandomResizedCrop(img_size, (0.8, 1.25)),
    # NOTE(review): brightness=(0.2) is just the scalar 0.2 — the parentheses
    # do NOT make a tuple. If a (min, max) brightness range was intended (as
    # for contrast), this is a latent bug.
    RandomApply([ColorJitter(brightness=(0.2), contrast=(0.85, 1.0))], p=0.4),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Validation-time transforms: deterministic resize (shorter side -> img_size)
# plus centre crop, with the same ImageNet normalisation as training.
valid_augs = Compose([
    ToPILImage(),
    Resize(img_size),
    CenterCrop(img_size),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Paired source/target datasets — indexing yields a (source, target) pair of
# sample dicts with keys 'x' (image), 'xa' (aux-transformed image) and 'a'
# (aux label index), as unpacked by the preview code below.
train_ds = AdaptationDataset(auxs, source_df_train, target_df, train_augs, label_cols_list)
valid_ds = AdaptationDataset(auxs, source_df_valid, target_df, valid_augs, label_cols_list)
from TedAI.tedai import make_imgs
import numpy as np
def _preview_random_pair(ds):
    """Visual sanity check: draw one random (source, target) pair from *ds*,
    print both auxiliary labels and render original + aux-transformed images.

    Extracted because the original script repeated this stanza verbatim for
    the train and valid datasets.
    """
    sample_src, sample_tgt = ds[np.random.randint(0, len(ds))]
    x_src, xa_src = sample_src['x'], sample_src['xa']
    x_tgt, xa_tgt = sample_tgt['x'], sample_tgt['xa']
    a_src, a_tgt = sample_src['a'], sample_tgt['a']
    batch = torch.stack([x_src, xa_src, x_tgt, xa_tgt])
    print(' '*40, 'source: ', aux_labels[a_src.item()], ' '*50, 'target: ', aux_labels[a_tgt.item()])
    make_imgs(batch, 4, plot=False)

# One random pair from each split.
_preview_random_pair(train_ds)
_preview_random_pair(valid_ds)
from torch.utils.data import DataLoader
# Each batch is a paired (source, target) sample, so batch_size=32 loads 32 of
# each per step. num_workers=16/8 assumes a many-core host — tune down if the
# machine is CPU-bound. pin_memory speeds host->GPU transfer.
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=16, pin_memory=True)
valid_dl = DataLoader(valid_ds, batch_size=16, shuffle=False, num_workers=8, pin_memory=True)
dataloaders = [train_dl, valid_dl]  # [train, valid] order expected by Learner
# Loss-term weights passed to the Learner. From the key names: src/tgt pick
# the domain; aux/kld/ent presumably weight the auxiliary-head loss, a
# KL-divergence term and a target entropy regulariser — confirm in ssda.learner.
W = dict(
    w_src_aux=0.5,
    w_src_kld=0.5,
    w_tgt_aux=0.5,
    w_tgt_kld=0.5,
    w_tgt_ent=0.1,
)
# Wire everything into the project Learner and launch training.
learn = Learner(model=adapt_model, dataloaders=dataloaders, name='first_run', W=W)
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR
n_epochs = 5
optimizer = Adam(adapt_model.parameters(), lr=1e-3)
# One-cycle LR schedule sized to exactly n_epochs passes over the train loader.
# NOTE(review): this assumes Learner.fit steps the scheduler once per batch —
# if it steps per epoch, total_steps should be n_epochs instead; confirm.
scheduler = OneCycleLR(optimizer, max_lr=1e-3, total_steps=len(dataloaders[0])*n_epochs)
learn.fit(n_epochs, optimizer, scheduler)